!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief module that contains the algorithms to perform an iterative
!>         diagonalization by the block-Davidson approach
!>         P. Blaha, et al J. Comp. Physics, 229, (2010), 453-460
!>         Iterative diagonalization in augmented plane wave based
!>         methods in electronic structure calculations
!> \par History
!>      05.2011 created [MI]
!> \author MI
! **************************************************************************************************
MODULE qs_scf_block_davidson

   USE cp_dbcsr_api,                    ONLY: &
        dbcsr_add, dbcsr_copy, dbcsr_create, dbcsr_get_info, dbcsr_init_p, &
        dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, dbcsr_iterator_start, &
        dbcsr_iterator_stop, dbcsr_iterator_type, dbcsr_multiply, dbcsr_release_p, dbcsr_type, &
        dbcsr_type_no_symmetry, dbcsr_type_symmetric
   USE cp_dbcsr_contrib,                ONLY: dbcsr_get_diag,&
                                              dbcsr_scale_by_vector
   USE cp_dbcsr_operations,             ONLY: copy_dbcsr_to_fm,&
                                              copy_fm_to_dbcsr,&
                                              cp_dbcsr_m_by_n_from_row_template,&
                                              cp_dbcsr_m_by_n_from_template,&
                                              cp_dbcsr_sm_fm_multiply
   USE cp_fm_basic_linalg,              ONLY: cp_fm_column_scale,&
                                              cp_fm_scale_and_add,&
                                              cp_fm_symm,&
                                              cp_fm_transpose,&
                                              cp_fm_triangular_invert,&
                                              cp_fm_uplo_to_full
   USE cp_fm_cholesky,                  ONLY: cp_fm_cholesky_decompose,&
                                              cp_fm_cholesky_restore
   USE cp_fm_diag,                      ONLY: choose_eigv_solver,&
                                              cp_fm_power
   USE cp_fm_struct,                    ONLY: cp_fm_struct_create,&
                                              cp_fm_struct_release,&
                                              cp_fm_struct_type
   USE cp_fm_types,                     ONLY: cp_fm_create,&
                                              cp_fm_get_diag,&
                                              cp_fm_release,&
                                              cp_fm_set_all,&
                                              cp_fm_to_fm,&
                                              cp_fm_to_fm_submat,&
                                              cp_fm_type,&
                                              cp_fm_vectorsnorm
   USE kinds,                           ONLY: dp
   USE machine,                         ONLY: m_walltime
   USE message_passing,                 ONLY: mp_comm_type
   USE parallel_gemm_api,               ONLY: parallel_gemm
   USE preconditioner,                  ONLY: apply_preconditioner
   USE preconditioner_types,            ONLY: preconditioner_type
   USE qs_block_davidson_types,         ONLY: davidson_type
   USE qs_mo_types,                     ONLY: get_mo_set,&
                                              mo_set_type
#include "./base/base_uses.f90"

   IMPLICIT NONE
   PRIVATE
   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'qs_scf_block_davidson'

   PUBLIC :: generate_extended_space, generate_extended_space_sparse

CONTAINS

! **************************************************************************************************
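!> \brief Block-Davidson update of the MO coefficients using full (dense) work matrices:
!>        builds the residual block Z = HC - SC*theta from the Ritz values theta, extends
!>        the space spanned by the not yet converged MOs with the (optionally
!>        preconditioned) Z and solves the reduced eigenvalue problem in that space
!> \param bdav_env Davidson environment; supplies max_iter, eps_iter, conv_percent and the
!>        work matrices matrix_z and matrix_pz
!> \param mo_set MO set whose coefficients (full and sparse copy) and eigenvalues are updated
!> \param matrix_h sparse matrix H of the generalized eigenvalue problem
!> \param matrix_s sparse matrix S of the generalized eigenvalue problem
!> \param output_unit unit for the iteration table; nothing is written if it is <= 0
!> \param preconditioner optional preconditioner applied to Z; if absent only one iteration is done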
! **************************************************************************************************
   SUBROUTINE generate_extended_space(bdav_env, mo_set, matrix_h, matrix_s, output_unit, &
                                      preconditioner)

      ! Block-Davidson update of the MO coefficients of mo_set, full-matrix (cp_fm) variant.
      !
      ! Every iteration
      !   1. forms HC and SC, takes diag(C^T H C) as Ritz values theta and builds the residual
      !      block Z = HC - SC*theta together with its column norms;
      !   2. counts an MO as converged when its residual norm is <= bdav_env%eps_iter and
      !      leaves the loop once the converged fraction exceeds bdav_env%conv_percent;
      !   3. otherwise assembles H and S in the basis [C, PZ] (C = not converged MOs,
      !      PZ = preconditioned residuals), calls reduce_extended_space and rebuilds the
      !      not converged MOs and their eigenvalues from the first half of the solution.
      ! If some MOs are already converged, H is replaced in step 3 by
      ! H + lambda*(S C_conv)(S C_conv)^T with lambda = 100*|eigenvalues(homo)|.
      !
      ! Arguments
      !   bdav_env       : max_iter, eps_iter and conv_percent are read; matrix_z and matrix_pz
      !                    serve as Z and PZ as long as no MO is converged
      !   mo_set         : mo_coeff, mo_coeff_b and eigenvalues obtained via get_mo_set are
      !                    overwritten with the updated vectors/values
      !   matrix_h       : sparse H of the generalized eigenvalue problem
      !   matrix_s       : sparse S of the generalized eigenvalue problem
      !   output_unit    : unit of the progress table, nothing is written if <= 0
      !   preconditioner : optional; without it a single iteration is done and PZ = Z

      TYPE(davidson_type)                                :: bdav_env
      TYPE(mo_set_type), INTENT(IN)                      :: mo_set
      TYPE(dbcsr_type), POINTER                          :: matrix_h, matrix_s
      INTEGER, INTENT(IN)                                :: output_unit
      TYPE(preconditioner_type), OPTIONAL, POINTER       :: preconditioner

      CHARACTER(len=*), PARAMETER :: routineN = 'generate_extended_space'

      INTEGER :: handle, homo, i_first, i_last, imo, iter, j, jj, max_iter, n, nao, nmat, nmat2, &
         nmo, nmo_converged, nmo_not_converged, nset, nset_conv, nset_not_conv
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: iconv, inotconv
      INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: iconv_set, inotconv_set
      LOGICAL                                            :: do_apply_preconditioner
      REAL(dp)                                           :: lambda, max_norm, min_norm, t1, t2
      REAL(dp), ALLOCATABLE, DIMENSION(:)                :: ritz_coeff, vnorm
      REAL(dp), DIMENSION(:), POINTER                    :: eig_not_conv, eigenvalues, evals
      TYPE(cp_fm_struct_type), POINTER                   :: fm_struct_tmp
      TYPE(cp_fm_type)                                   :: c_conv, c_notconv, c_out, h_block, h_fm, &
                                                            m_hc, m_sc, m_tmp, mt_tmp, s_block, &
                                                            s_fm, v_block, w_block
      TYPE(cp_fm_type), POINTER                          :: c_pz, c_z, mo_coeff
      TYPE(dbcsr_type), POINTER                          :: mo_coeff_b

      CALL timeset(routineN, handle)

      NULLIFY (mo_coeff, mo_coeff_b, eigenvalues)

      do_apply_preconditioner = .FALSE.
      IF (PRESENT(preconditioner)) do_apply_preconditioner = .TRUE.
      CALL get_mo_set(mo_set=mo_set, mo_coeff=mo_coeff, mo_coeff_b=mo_coeff_b, eigenvalues=eigenvalues, &
                      nao=nao, nmo=nmo, homo=homo)
      ! without a preconditioner only a single Davidson step is performed
      IF (do_apply_preconditioner) THEN
         max_iter = bdav_env%max_iter
      ELSE
         max_iter = 1
      END IF

      NULLIFY (c_z, c_pz)
      NULLIFY (evals, eig_not_conv)
      t1 = m_walltime()
      IF (output_unit > 0) THEN
         WRITE (output_unit, "(T15,A,T23,A,T36,A,T49,A,T60,A,/,T8,A)") &
            " Cycle ", " conv. MOS ", " B2MAX ", " B2MIN ", " Time", REPEAT("-", 60)
      END IF

      ! iconv/inotconv: indices of the converged / not converged MOs (first nmo_converged /
      ! nmo_not_converged entries are meaningful)
      ALLOCATE (iconv(nmo))
      ALLOCATE (inotconv(nmo))
      ALLOCATE (ritz_coeff(nmo))
      ALLOCATE (vnorm(nmo))

      DO iter = 1, max_iter

         ! compute Ritz values: diagonal of C^T H C
         ritz_coeff = 0.0_dp
         CALL cp_fm_create(m_hc, mo_coeff%matrix_struct, name="hc")
         CALL cp_dbcsr_sm_fm_multiply(matrix_h, mo_coeff, m_hc, nmo)
         CALL cp_fm_create(m_sc, mo_coeff%matrix_struct, name="sc")
         CALL cp_dbcsr_sm_fm_multiply(matrix_s, mo_coeff, m_sc, nmo)

         CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=nmo, ncol_global=nmo, &
                                  context=mo_coeff%matrix_struct%context, &
                                  para_env=mo_coeff%matrix_struct%para_env)
         CALL cp_fm_create(m_tmp, fm_struct_tmp, name="matrix_tmp")
         CALL cp_fm_struct_release(fm_struct_tmp)

         CALL parallel_gemm('T', 'N', nmo, nmo, nao, 1.0_dp, mo_coeff, m_hc, 0.0_dp, m_tmp)
         CALL cp_fm_get_diag(m_tmp, ritz_coeff)
         CALL cp_fm_release(m_tmp)

         ! Check for converged eigenvectors
         ! residual block in the work matrix of the environment: c_z = HC - SC*theta
         c_z => bdav_env%matrix_z
         c_pz => bdav_env%matrix_pz
         CALL cp_fm_to_fm(m_sc, c_z)
         CALL cp_fm_column_scale(c_z, ritz_coeff)
         CALL cp_fm_scale_and_add(-1.0_dp, c_z, 1.0_dp, m_hc)
         CALL cp_fm_vectorsnorm(c_z, vnorm)

         nmo_converged = 0
         nmo_not_converged = 0
         ! largest/smallest residual norm, only used for the progress table
         max_norm = 0.0_dp
         min_norm = 1.e10_dp
         DO imo = 1, nmo
            max_norm = MAX(max_norm, vnorm(imo))
            min_norm = MIN(min_norm, vnorm(imo))
         END DO
         iconv = 0
         inotconv = 0
         DO imo = 1, nmo
            IF (vnorm(imo) <= bdav_env%eps_iter) THEN
               nmo_converged = nmo_converged + 1
               iconv(nmo_converged) = imo
            ELSE
               nmo_not_converged = nmo_not_converged + 1
               inotconv(nmo_not_converged) = imo
            END IF
         END DO

         IF (nmo_converged > 0) THEN
            ! Group the (ascending) MO indices into runs of consecutive indices so that whole
            ! column ranges can be copied at once: *_set(k, 1:2) = first and last index of run k
            ALLOCATE (iconv_set(nmo_converged, 2))
            ALLOCATE (inotconv_set(nmo_not_converged, 2))
            ! i_last starts as the first index, hence the first pass always opens run 1
            i_last = iconv(1)
            nset = 0
            DO j = 1, nmo_converged
               imo = iconv(j)

               IF (imo == i_last + 1) THEN
                  i_last = imo
                  iconv_set(nset, 2) = imo
               ELSE
                  i_last = imo
                  nset = nset + 1
                  iconv_set(nset, 1) = imo
                  iconv_set(nset, 2) = imo
               END IF
            END DO
            nset_conv = nset

            i_last = inotconv(1)
            nset = 0
            DO j = 1, nmo_not_converged
               imo = inotconv(j)

               IF (imo == i_last + 1) THEN
                  i_last = imo
                  inotconv_set(nset, 2) = imo
               ELSE
                  i_last = imo
                  nset = nset + 1
                  inotconv_set(nset, 1) = imo
                  inotconv_set(nset, 2) = imo
               END IF
            END DO
            nset_not_conv = nset
            ! the full-size HC, SC are not needed in this case: they are rebuilt below for
            ! the not converged columns only, and private Z/PZ matrices replace the
            ! environment ones
            CALL cp_fm_release(m_sc)
            CALL cp_fm_release(m_hc)
            NULLIFY (c_z, c_pz)
         END IF

         ! NOTE(review): the comparison is strict, so with conv_percent >= 1 this exit is
         ! never taken even if every MO is converged - confirm that this is intended
         IF (REAL(nmo_converged, dp)/REAL(nmo, dp) > bdav_env%conv_percent) THEN
            ! The index sets are allocated only if at least one MO converged, and only in
            ! that case HC and SC have been released above: clean up whatever is still alive
            IF (ALLOCATED(iconv_set)) DEALLOCATE (iconv_set)
            IF (ALLOCATED(inotconv_set)) DEALLOCATE (inotconv_set)
            IF (nmo_converged == 0) THEN
               CALL cp_fm_release(m_sc)
               CALL cp_fm_release(m_hc)
            END IF
            t2 = m_walltime()
            IF (output_unit > 0) THEN
               WRITE (output_unit, '(T16,I5,T24,I6,T33,E12.4,2x,E12.4,T60,F8.3)') &
                  iter, nmo_converged, max_norm, min_norm, t2 - t1

               WRITE (output_unit, *) " Reached convergence in ", iter, &
                  " Davidson iterations"
            END IF

            EXIT
         END IF

         IF (nmo_converged > 0) THEN
            CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=nao, ncol_global=nao, &
                                     context=mo_coeff%matrix_struct%context, &
                                     para_env=mo_coeff%matrix_struct%para_env)
            !allocate h_fm
            CALL cp_fm_create(h_fm, fm_struct_tmp, name="matrix_tmp")
            !allocate s_fm
            CALL cp_fm_create(s_fm, fm_struct_tmp, name="matrix_tmp")
            !copy matrix_h in h_fm
            ! s_fm is only handed over as second argument here (presumably work space);
            ! its content is replaced by S right below
            CALL copy_dbcsr_to_fm(matrix_h, h_fm)
            CALL cp_fm_uplo_to_full(h_fm, s_fm)

            !copy matrix_s in s_fm
            CALL copy_dbcsr_to_fm(matrix_s, s_fm)

            !allocate c_out
            CALL cp_fm_create(c_out, fm_struct_tmp, name="matrix_tmp")
            ! set c_out to zero
            CALL cp_fm_set_all(c_out, 0.0_dp)
            CALL cp_fm_struct_release(fm_struct_tmp)

            !allocate c_conv
            CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=nao, ncol_global=nmo_converged, &
                                     context=mo_coeff%matrix_struct%context, &
                                     para_env=mo_coeff%matrix_struct%para_env)
            CALL cp_fm_create(c_conv, fm_struct_tmp, name="c_conv")
            CALL cp_fm_set_all(c_conv, 0.0_dp)
            !allocate m_tmp
            CALL cp_fm_create(m_tmp, fm_struct_tmp, name="m_tmp_nxmc")
            CALL cp_fm_struct_release(fm_struct_tmp)
         END IF

         !allocate c_notconv
         ! this nao x nmo_not_converged structure is kept until m_tmp has been created below
         CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=nao, ncol_global=nmo_not_converged, &
                                  context=mo_coeff%matrix_struct%context, &
                                  para_env=mo_coeff%matrix_struct%para_env)
         CALL cp_fm_create(c_notconv, fm_struct_tmp, name="c_notconv")
         CALL cp_fm_set_all(c_notconv, 0.0_dp)
         IF (nmo_converged > 0) THEN
            CALL cp_fm_create(m_hc, fm_struct_tmp, name="m_hc")
            CALL cp_fm_create(m_sc, fm_struct_tmp, name="m_sc")
            !allocate c_z
            ! private Z and PZ with nmo_not_converged columns, released at the end of the iteration
            ALLOCATE (c_z, c_pz)
            CALL cp_fm_create(c_z, fm_struct_tmp, name="c_z")
            CALL cp_fm_create(c_pz, fm_struct_tmp, name="c_pz")
            CALL cp_fm_set_all(c_z, 0.0_dp)

            ! sum contributions to c_out
            ! gather the converged MOs run by run into c_conv, then c_out = (S C_conv)(S C_conv)^T
            jj = 1
            DO j = 1, nset_conv
               i_first = iconv_set(j, 1)
               i_last = iconv_set(j, 2)
               n = i_last - i_first + 1
               CALL cp_fm_to_fm_submat(mo_coeff, c_conv, nao, n, 1, i_first, 1, jj)
               jj = jj + n
            END DO
            CALL cp_fm_symm('L', 'U', nao, nmo_converged, 1.0_dp, s_fm, c_conv, 0.0_dp, m_tmp)
            CALL parallel_gemm('N', 'T', nao, nao, nmo_converged, 1.0_dp, m_tmp, m_tmp, 0.0_dp, c_out)

            ! project c_out out of H
            ! c_out = lambda*c_out + H, i.e. from here on c_out holds the shifted H
            lambda = 100.0_dp*ABS(eigenvalues(homo))
            CALL cp_fm_scale_and_add(lambda, c_out, 1.0_dp, h_fm)
            CALL cp_fm_release(m_tmp)
            CALL cp_fm_release(h_fm)

         END IF

         !allocate m_tmp
         CALL cp_fm_create(m_tmp, fm_struct_tmp, name="m_tmp_nxm")
         CALL cp_fm_struct_release(fm_struct_tmp)
         IF (nmo_converged > 0) THEN
            ! gather the not converged MOs and their Ritz values run by run
            ALLOCATE (eig_not_conv(nmo_not_converged))
            jj = 1
            DO j = 1, nset_not_conv
               i_first = inotconv_set(j, 1)
               i_last = inotconv_set(j, 2)
               n = i_last - i_first + 1
               CALL cp_fm_to_fm_submat(mo_coeff, c_notconv, nao, n, 1, i_first, 1, jj)
               eig_not_conv(jj:jj + n - 1) = ritz_coeff(i_first:i_last)
               jj = jj + n
            END DO
            ! m_hc = (shifted H) C and m_sc = S C for the not converged MOs
            CALL parallel_gemm('N', 'N', nao, nmo_not_converged, nao, 1.0_dp, c_out, c_notconv, 0.0_dp, m_hc)
            CALL cp_fm_symm('L', 'U', nao, nmo_not_converged, 1.0_dp, s_fm, c_notconv, 0.0_dp, m_sc)
            ! extend subspace using only the not converged vectors
            ! residual of the not converged MOs: m_tmp = m_hc - m_sc*theta, copied to c_z
            CALL cp_fm_to_fm(m_sc, m_tmp)
            CALL cp_fm_column_scale(m_tmp, eig_not_conv)
            CALL cp_fm_scale_and_add(-1.0_dp, m_tmp, 1.0_dp, m_hc)
            DEALLOCATE (eig_not_conv)
            CALL cp_fm_to_fm(m_tmp, c_z)
         ELSE
            ! nothing converged: all MOs are treated as not converged; c_z, m_hc and m_sc
            ! from the beginning of the iteration are used as they are
            CALL cp_fm_to_fm(mo_coeff, c_notconv)
         END IF

         !preconditioner
         ! c_pz = P c_z, or a plain copy of c_z if no preconditioner is given or in use
         IF (do_apply_preconditioner) THEN
            IF (preconditioner%in_use /= 0) THEN
               CALL apply_preconditioner(preconditioner, c_z, c_pz)
            ELSE
               CALL cp_fm_to_fm(c_z, c_pz)
            END IF
         ELSE
            CALL cp_fm_to_fm(c_z, c_pz)
         END IF
         CALL cp_fm_release(m_tmp)

         ! m x m work matrices, m = nmo_not_converged
         CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=nmo_not_converged, ncol_global=nmo_not_converged, &
                                  context=mo_coeff%matrix_struct%context, &
                                  para_env=mo_coeff%matrix_struct%para_env)

         CALL cp_fm_create(m_tmp, fm_struct_tmp, name="m_tmp_mxm")
         CALL cp_fm_create(mt_tmp, fm_struct_tmp, name="mt_tmp_mxm")
         CALL cp_fm_struct_release(fm_struct_tmp)

         ! 2m x 2m matrices of the reduced problem in the basis [C, PZ]
         nmat = nmo_not_converged
         nmat2 = 2*nmo_not_converged
         CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=nmat2, ncol_global=nmat2, &
                                  context=mo_coeff%matrix_struct%context, &
                                  para_env=mo_coeff%matrix_struct%para_env)

         CALL cp_fm_create(s_block, fm_struct_tmp, name="sb")
         CALL cp_fm_create(h_block, fm_struct_tmp, name="hb")
         CALL cp_fm_create(v_block, fm_struct_tmp, name="vb")
         CALL cp_fm_create(w_block, fm_struct_tmp, name="wb")
         ALLOCATE (evals(nmat2))

         CALL cp_fm_struct_release(fm_struct_tmp)

         ! CSC block: not computed, s_block is initialised to the unit matrix and its
         ! upper left block is never overwritten below
         CALL cp_fm_set_all(s_block, 0.0_dp, 1.0_dp)

         ! compute CHC (upper left block of h_block)
         CALL parallel_gemm('T', 'N', nmat, nmat, nao, 1.0_dp, c_notconv, m_hc, 0.0_dp, m_tmp)
         CALL cp_fm_to_fm_submat(m_tmp, h_block, nmat, nmat, 1, 1, 1, 1)

         ! compute ZSC (lower left block, its transpose goes to the upper right block)
         CALL parallel_gemm('T', 'N', nmat, nmat, nao, 1.0_dp, c_pz, m_sc, 0.0_dp, m_tmp)
         CALL cp_fm_to_fm_submat(m_tmp, s_block, nmat, nmat, 1, 1, 1 + nmat, 1)
         CALL cp_fm_transpose(m_tmp, mt_tmp)
         CALL cp_fm_to_fm_submat(mt_tmp, s_block, nmat, nmat, 1, 1, 1, 1 + nmat)
         ! compute ZHC (lower left block, its transpose goes to the upper right block)
         CALL parallel_gemm('T', 'N', nmat, nmat, nao, 1.0_dp, c_pz, m_hc, 0.0_dp, m_tmp)
         CALL cp_fm_to_fm_submat(m_tmp, h_block, nmat, nmat, 1, 1, 1 + nmat, 1)
         CALL cp_fm_transpose(m_tmp, mt_tmp)
         CALL cp_fm_to_fm_submat(mt_tmp, h_block, nmat, nmat, 1, 1, 1, 1 + nmat)

         CALL cp_fm_release(mt_tmp)

         ! reuse m_sc and m_hc to compute HZ and SZ
         IF (nmo_converged > 0) THEN
            CALL parallel_gemm('N', 'N', nao, nmat, nao, 1.0_dp, c_out, c_pz, 0.0_dp, m_hc)
            CALL cp_fm_symm('L', 'U', nao, nmo_not_converged, 1.0_dp, s_fm, c_pz, 0.0_dp, m_sc)

            ! last use of the full nao x nao matrices
            CALL cp_fm_release(c_out)
            CALL cp_fm_release(c_conv)
            CALL cp_fm_release(s_fm)
         ELSE
            CALL cp_dbcsr_sm_fm_multiply(matrix_h, c_pz, m_hc, nmo)
            CALL cp_dbcsr_sm_fm_multiply(matrix_s, c_pz, m_sc, nmo)
         END IF

         ! compute ZSZ (lower right block)
         CALL parallel_gemm('T', 'N', nmat, nmat, nao, 1.0_dp, c_pz, m_sc, 0.0_dp, m_tmp)
         CALL cp_fm_to_fm_submat(m_tmp, s_block, nmat, nmat, 1, 1, 1 + nmat, 1 + nmat)
         ! compute ZHZ (lower right block)
         CALL parallel_gemm('T', 'N', nmat, nmat, nao, 1.0_dp, c_pz, m_hc, 0.0_dp, m_tmp)
         CALL cp_fm_to_fm_submat(m_tmp, h_block, nmat, nmat, 1, 1, 1 + nmat, 1 + nmat)

         CALL cp_fm_release(m_sc)

         ! solution of the reduced eigenvalues problem
         CALL reduce_extended_space(s_block, h_block, v_block, w_block, evals, nmat2)

         ! extract eigenvectors
         ! new vectors = C*V(1:m, 1:m) + PZ*V(m+1:2m, 1:m), accumulated in m_hc
         CALL cp_fm_to_fm_submat(v_block, m_tmp, nmat, nmat, 1, 1, 1, 1)
         CALL parallel_gemm('N', 'N', nao, nmat, nmat, 1.0_dp, c_notconv, m_tmp, 0.0_dp, m_hc)
         CALL cp_fm_to_fm_submat(v_block, m_tmp, nmat, nmat, 1 + nmat, 1, 1, 1)
         CALL parallel_gemm('N', 'N', nao, nmat, nmat, 1.0_dp, c_pz, m_tmp, 1.0_dp, m_hc)

         CALL cp_fm_release(m_tmp)

         CALL cp_fm_release(c_notconv)
         CALL cp_fm_release(s_block)
         CALL cp_fm_release(h_block)
         CALL cp_fm_release(w_block)
         CALL cp_fm_release(v_block)

         IF (nmo_converged > 0) THEN
            ! c_z and c_pz are private matrices in this case
            CALL cp_fm_release(c_z)
            CALL cp_fm_release(c_pz)
            DEALLOCATE (c_z, c_pz)
            ! scatter the updated vectors and eigenvalues back to their original MO positions
            jj = 1
            DO j = 1, nset_not_conv
               i_first = inotconv_set(j, 1)
               i_last = inotconv_set(j, 2)
               n = i_last - i_first + 1
               CALL cp_fm_to_fm_submat(m_hc, mo_coeff, nao, n, 1, jj, 1, i_first)
               eigenvalues(i_first:i_last) = evals(jj:jj + n - 1)
               jj = jj + n
            END DO
            DEALLOCATE (iconv_set)
            DEALLOCATE (inotconv_set)
         ELSE
            CALL cp_fm_to_fm(m_hc, mo_coeff)
            eigenvalues(1:nmo) = evals(1:nmo)
         END IF
         DEALLOCATE (evals)

         CALL cp_fm_release(m_hc)

         ! keep the sparse copy of the coefficients in sync with the full one
         CALL copy_fm_to_dbcsr(mo_coeff, mo_coeff_b) !fm->dbcsr

         t2 = m_walltime()
         IF (output_unit > 0) THEN
            WRITE (output_unit, '(T16,I5,T24,I6,T33,E12.4,2x,E12.4,T60,F8.3)') &
               iter, nmo_converged, max_norm, min_norm, t2 - t1
         END IF
         ! the timer restarts here, the printed time is per iteration
         t1 = m_walltime()

      END DO ! iter

      DEALLOCATE (iconv)
      DEALLOCATE (inotconv)
      DEALLOCATE (ritz_coeff)
      DEALLOCATE (vnorm)

      CALL timestop(handle)
   END SUBROUTINE generate_extended_space

! **************************************************************************************************
!> \brief ...
!> \param bdav_env ...
!> \param mo_set ...
!> \param matrix_h ...
!> \param matrix_s ...
!> \param output_unit ...
!> \param preconditioner ...
! **************************************************************************************************
   SUBROUTINE generate_extended_space_sparse(bdav_env, mo_set, matrix_h, matrix_s, output_unit, &
                                             preconditioner)

      TYPE(davidson_type)                                :: bdav_env
      TYPE(mo_set_type), INTENT(IN)                      :: mo_set
      TYPE(dbcsr_type), POINTER                          :: matrix_h, matrix_s
      INTEGER, INTENT(IN)                                :: output_unit
      TYPE(preconditioner_type), OPTIONAL, POINTER       :: preconditioner

      CHARACTER(len=*), PARAMETER :: routineN = 'generate_extended_space_sparse'

      INTEGER :: col_offset, handle, homo, i_first, i_last, imo, iteration, j, jj, k, max_iter, n, &
         nao, nmat, nmat2, nmo, nmo_converged, nmo_not_converged, nset, nset_conv, nset_not_conv
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: iconv, inotconv
      INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: iconv_set, inotconv_set
      LOGICAL                                            :: converged, do_apply_preconditioner
      REAL(dp)                                           :: lambda, max_norm, min_norm, t1, t2
      REAL(dp), ALLOCATABLE, DIMENSION(:)                :: eig_not_conv, evals, ritz_coeff, vnorm
      REAL(dp), DIMENSION(:), POINTER                    :: eigenvalues
      REAL(dp), DIMENSION(:, :), POINTER                 :: block
      TYPE(cp_fm_struct_type), POINTER                   :: fm_struct_tmp
      TYPE(cp_fm_type)                                   :: h_block, matrix_mm_fm, matrix_mmt_fm, &
                                                            matrix_nm_fm, matrix_z_fm, mo_conv_fm, &
                                                            s_block, v_block, w_block
      TYPE(cp_fm_type), POINTER                          :: mo_coeff, mo_notconv_fm
      TYPE(dbcsr_iterator_type)                          :: iter
      TYPE(dbcsr_type), POINTER                          :: c_out, matrix_hc, matrix_mm, matrix_pz, &
                                                            matrix_sc, matrix_z, mo_coeff_b, &
                                                            mo_conv, mo_notconv, smo_conv
      TYPE(mp_comm_type)                                 :: group

      CALL timeset(routineN, handle)

      do_apply_preconditioner = .FALSE.
      IF (PRESENT(preconditioner)) do_apply_preconditioner = .TRUE.

      NULLIFY (mo_coeff, mo_coeff_b, matrix_hc, matrix_sc, matrix_z, matrix_pz, matrix_mm)
      NULLIFY (mo_notconv_fm, mo_conv, mo_notconv, smo_conv, c_out)
      NULLIFY (fm_struct_tmp)
      CALL get_mo_set(mo_set=mo_set, mo_coeff=mo_coeff, mo_coeff_b=mo_coeff_b, &
                      eigenvalues=eigenvalues, homo=homo, nao=nao, nmo=nmo)
      IF (do_apply_preconditioner) THEN
         max_iter = bdav_env%max_iter
      ELSE
         max_iter = 1
      END IF

      t1 = m_walltime()
      IF (output_unit > 0) THEN
         WRITE (output_unit, "(T15,A,T23,A,T36,A,T49,A,T60,A,/,T8,A)") &
            " Cycle ", " conv. MOS ", " B2MAX ", " B2MIN ", " Time", REPEAT("-", 60)
      END IF

      ! Allocate array for Ritz values
      ALLOCATE (ritz_coeff(nmo))
      ALLOCATE (iconv(nmo))
      ALLOCATE (inotconv(nmo))
      ALLOCATE (vnorm(nmo))

      converged = .FALSE.
      DO iteration = 1, max_iter
         NULLIFY (c_out, mo_conv, mo_notconv_fm, mo_notconv)
         ! Prepare HC and SC, using mo_coeff_b (sparse), these are still sparse
         CALL dbcsr_init_p(matrix_hc)
         CALL dbcsr_create(matrix_hc, template=mo_coeff_b, &
                           name="matrix_hc", &
                           matrix_type=dbcsr_type_no_symmetry)
         CALL dbcsr_init_p(matrix_sc)
         CALL dbcsr_create(matrix_sc, template=mo_coeff_b, &
                           name="matrix_sc", &
                           matrix_type=dbcsr_type_no_symmetry)

         CALL dbcsr_get_info(mo_coeff_b, nfullrows_total=n, nfullcols_total=k, group=group)
         CALL dbcsr_multiply('n', 'n', 1.0_dp, matrix_h, mo_coeff_b, 0.0_dp, matrix_hc, last_column=k)
         CALL dbcsr_multiply('n', 'n', 1.0_dp, matrix_s, mo_coeff_b, 0.0_dp, matrix_sc, last_column=k)

         ! compute Ritz values
         ritz_coeff = 0.0_dp
         ! Allocate Sparse matrices: nmoxnmo
         ! matrix_mm

         CALL dbcsr_init_p(matrix_mm)
         CALL cp_dbcsr_m_by_n_from_template(matrix_mm, template=matrix_s, m=nmo, n=nmo, &
                                            sym=dbcsr_type_no_symmetry)

         CALL dbcsr_multiply('t', 'n', 1.0_dp, mo_coeff_b, matrix_hc, 0.0_dp, matrix_mm, last_column=k)
         CALL dbcsr_get_diag(matrix_mm, ritz_coeff)
         CALL mo_coeff%matrix_struct%para_env%sum(ritz_coeff)

         ! extended subspace P Z = P [H - theta S]C  this ia another matrix of type and size as mo_coeff_b
         CALL dbcsr_init_p(matrix_z)
         CALL dbcsr_create(matrix_z, template=mo_coeff_b, &
                           name="matrix_z", &
                           matrix_type=dbcsr_type_no_symmetry)
         CALL dbcsr_copy(matrix_z, matrix_sc)
         CALL dbcsr_scale_by_vector(matrix_z, ritz_coeff, side='right')
         CALL dbcsr_add(matrix_z, matrix_hc, -1.0_dp, 1.0_dp)

         ! Compute the column norms of matrix_z.
         vnorm = 0.0_dp
         CALL dbcsr_iterator_start(iter, matrix_z)
         DO WHILE (dbcsr_iterator_blocks_left(iter))
            CALL dbcsr_iterator_next_block(iter, block=block, col_offset=col_offset)
            DO j = 1, SIZE(block, 2)
               vnorm(col_offset + j - 1) = vnorm(col_offset + j - 1) + SUM(block(:, j)**2)
            END DO
         END DO
         CALL dbcsr_iterator_stop(iter)
         CALL group%sum(vnorm)
         vnorm = SQRT(vnorm)

         ! Check for converged eigenvectors
         nmo_converged = 0
         nmo_not_converged = 0
         max_norm = 0.0_dp
         min_norm = 1.e10_dp
         DO imo = 1, nmo
            max_norm = MAX(max_norm, vnorm(imo))
            min_norm = MIN(min_norm, vnorm(imo))
         END DO
         iconv = 0
         inotconv = 0

         DO imo = 1, nmo
            IF (vnorm(imo) <= bdav_env%eps_iter) THEN
               nmo_converged = nmo_converged + 1
               iconv(nmo_converged) = imo
            ELSE
               nmo_not_converged = nmo_not_converged + 1
               inotconv(nmo_not_converged) = imo
            END IF
         END DO

         IF (nmo_converged > 0) THEN
            ALLOCATE (iconv_set(nmo_converged, 2))
            ALLOCATE (inotconv_set(nmo_not_converged, 2))
            i_last = iconv(1)
            nset = 0
            DO j = 1, nmo_converged
               imo = iconv(j)

               IF (imo == i_last + 1) THEN
                  i_last = imo
                  iconv_set(nset, 2) = imo
               ELSE
                  i_last = imo
                  nset = nset + 1
                  iconv_set(nset, 1) = imo
                  iconv_set(nset, 2) = imo
               END IF
            END DO
            nset_conv = nset

            i_last = inotconv(1)
            nset = 0
            DO j = 1, nmo_not_converged
               imo = inotconv(j)

               IF (imo == i_last + 1) THEN
                  i_last = imo
                  inotconv_set(nset, 2) = imo
               ELSE
                  i_last = imo
                  nset = nset + 1
                  inotconv_set(nset, 1) = imo
                  inotconv_set(nset, 2) = imo
               END IF
            END DO
            nset_not_conv = nset

            CALL dbcsr_release_p(matrix_hc)
            CALL dbcsr_release_p(matrix_sc)
            CALL dbcsr_release_p(matrix_z)
            CALL dbcsr_release_p(matrix_mm)
         END IF

         IF (REAL(nmo_converged, dp)/REAL(nmo, dp) > bdav_env%conv_percent) THEN
            DEALLOCATE (iconv_set)

            DEALLOCATE (inotconv_set)

            converged = .TRUE.
            t2 = m_walltime()
            IF (output_unit > 0) THEN
               WRITE (output_unit, '(T16,I5,T24,I6,T33,E12.4,2x,E12.4,T60,F8.3)') &
                  iteration, nmo_converged, max_norm, min_norm, t2 - t1

               WRITE (output_unit, *) " Reached convergence in ", iteration, &
                  " Davidson iterations"
            END IF

            EXIT
         END IF

         IF (nmo_converged > 0) THEN

            !allocate mo_conv_fm
            CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=nao, ncol_global=nmo_converged, &
                                     context=mo_coeff%matrix_struct%context, &
                                     para_env=mo_coeff%matrix_struct%para_env)
            CALL cp_fm_create(mo_conv_fm, fm_struct_tmp, name="mo_conv_fm")

            CALL cp_fm_struct_release(fm_struct_tmp)

            ! extract mo_conv from mo_coeff full matrix
            jj = 1
            DO j = 1, nset_conv
               i_first = iconv_set(j, 1)
               i_last = iconv_set(j, 2)
               n = i_last - i_first + 1
               CALL cp_fm_to_fm_submat(mo_coeff, mo_conv_fm, nao, n, 1, i_first, 1, jj)
               jj = jj + n
            END DO

            ! allocate c_out sparse matrix, to project out the converged MOS
            CALL dbcsr_init_p(c_out)
            CALL dbcsr_create(c_out, template=matrix_s, &
                              name="c_out", &
                              matrix_type=dbcsr_type_symmetric)

            ! allocate mo_conv sparse
            CALL dbcsr_init_p(mo_conv)
            CALL cp_dbcsr_m_by_n_from_row_template(mo_conv, template=matrix_s, n=nmo_converged, &
                                                   sym=dbcsr_type_no_symmetry)

            CALL dbcsr_init_p(smo_conv)
            CALL cp_dbcsr_m_by_n_from_row_template(smo_conv, template=matrix_s, n=nmo_converged, &
                                                   sym=dbcsr_type_no_symmetry)

            CALL copy_fm_to_dbcsr(mo_conv_fm, mo_conv) !fm->dbcsr

            CALL dbcsr_multiply('n', 'n', 1.0_dp, matrix_s, mo_conv, 0.0_dp, smo_conv, last_column=nmo_converged)
            CALL dbcsr_multiply('n', 't', 1.0_dp, smo_conv, smo_conv, 0.0_dp, c_out, last_column=nao)
            ! project c_out out of H
            lambda = 100.0_dp*ABS(eigenvalues(homo))
            CALL dbcsr_add(c_out, matrix_h, lambda, 1.0_dp)

            CALL dbcsr_release_p(mo_conv)
            CALL dbcsr_release_p(smo_conv)
            CALL cp_fm_release(mo_conv_fm)

            !allocate c_notconv_fm
            CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=nao, ncol_global=nmo_not_converged, &
                                     context=mo_coeff%matrix_struct%context, &
                                     para_env=mo_coeff%matrix_struct%para_env)
            ALLOCATE (mo_notconv_fm)
            CALL cp_fm_create(mo_notconv_fm, fm_struct_tmp, name="mo_notconv_fm")
            CALL cp_fm_struct_release(fm_struct_tmp)

            ! extract mo_notconv from mo_coeff full matrix
            ALLOCATE (eig_not_conv(nmo_not_converged))

            jj = 1
            DO j = 1, nset_not_conv
               i_first = inotconv_set(j, 1)
               i_last = inotconv_set(j, 2)
               n = i_last - i_first + 1
               CALL cp_fm_to_fm_submat(mo_coeff, mo_notconv_fm, nao, n, 1, i_first, 1, jj)
               eig_not_conv(jj:jj + n - 1) = ritz_coeff(i_first:i_last)
               jj = jj + n
            END DO

            ! allocate mo_notconv sparse
            CALL dbcsr_init_p(mo_notconv)
            CALL cp_dbcsr_m_by_n_from_row_template(mo_notconv, template=matrix_s, n=nmo_not_converged, &
                                                   sym=dbcsr_type_no_symmetry)

            CALL dbcsr_init_p(matrix_hc)
            CALL cp_dbcsr_m_by_n_from_row_template(matrix_hc, template=matrix_s, n=nmo_not_converged, &
                                                   sym=dbcsr_type_no_symmetry)

            CALL dbcsr_init_p(matrix_sc)
            CALL cp_dbcsr_m_by_n_from_row_template(matrix_sc, template=matrix_s, n=nmo_not_converged, &
                                                   sym=dbcsr_type_no_symmetry)

            CALL dbcsr_init_p(matrix_z)
            CALL cp_dbcsr_m_by_n_from_row_template(matrix_z, template=matrix_s, n=nmo_not_converged, &
                                                   sym=dbcsr_type_no_symmetry)

            CALL copy_fm_to_dbcsr(mo_notconv_fm, mo_notconv) !fm->dbcsr

            CALL dbcsr_multiply('n', 'n', 1.0_dp, c_out, mo_notconv, 0.0_dp, matrix_hc, &
                                last_column=nmo_not_converged)
            CALL dbcsr_multiply('n', 'n', 1.0_dp, matrix_s, mo_notconv, 0.0_dp, matrix_sc, &
                                last_column=nmo_not_converged)

            CALL dbcsr_copy(matrix_z, matrix_sc)
            CALL dbcsr_scale_by_vector(matrix_z, eig_not_conv, side='right')
            CALL dbcsr_add(matrix_z, matrix_hc, -1.0_dp, 1.0_dp)

            DEALLOCATE (eig_not_conv)

            ! matrix_mm
            CALL dbcsr_init_p(matrix_mm)
            CALL cp_dbcsr_m_by_n_from_template(matrix_mm, template=matrix_s, m=nmo_not_converged, n=nmo_not_converged, &
                                               sym=dbcsr_type_no_symmetry)

            CALL dbcsr_multiply('t', 'n', 1.0_dp, mo_notconv, matrix_hc, 0.0_dp, matrix_mm, &
                                last_column=nmo_not_converged)

         ELSE
            mo_notconv => mo_coeff_b
            mo_notconv_fm => mo_coeff
            c_out => matrix_h
         END IF

         ! allocate matrix_pz using as template matrix_z
         CALL dbcsr_init_p(matrix_pz)
         CALL dbcsr_create(matrix_pz, template=matrix_z, &
                           name="matrix_pz", &
                           matrix_type=dbcsr_type_no_symmetry)

         IF (do_apply_preconditioner) THEN
            IF (preconditioner%in_use /= 0) THEN
               CALL apply_preconditioner(preconditioner, matrix_z, matrix_pz)
            ELSE
               CALL dbcsr_copy(matrix_pz, matrix_z)
            END IF
         ELSE
            CALL dbcsr_copy(matrix_pz, matrix_z)
         END IF

         !allocate NMOxNMO  full matrices
         nmat = nmo_not_converged
         CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=nmat, ncol_global=nmat, &
                                  context=mo_coeff%matrix_struct%context, &
                                  para_env=mo_coeff%matrix_struct%para_env)
         CALL cp_fm_create(matrix_mm_fm, fm_struct_tmp, name="m_tmp_mxm")
         CALL cp_fm_create(matrix_mmt_fm, fm_struct_tmp, name="mt_tmp_mxm")
         CALL cp_fm_struct_release(fm_struct_tmp)

         !allocate 2NMOx2NMO full matrices
         nmat2 = 2*nmo_not_converged
         CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=nmat2, ncol_global=nmat2, &
                                  context=mo_coeff%matrix_struct%context, &
                                  para_env=mo_coeff%matrix_struct%para_env)

         CALL cp_fm_create(s_block, fm_struct_tmp, name="sb")
         CALL cp_fm_create(h_block, fm_struct_tmp, name="hb")
         CALL cp_fm_create(v_block, fm_struct_tmp, name="vb")
         CALL cp_fm_create(w_block, fm_struct_tmp, name="wb")
         ALLOCATE (evals(nmat2))
         CALL cp_fm_struct_release(fm_struct_tmp)

         ! compute CSC
         CALL cp_fm_set_all(s_block, 0.0_dp, 1.0_dp)
         ! compute CHC
         CALL copy_dbcsr_to_fm(matrix_mm, matrix_mm_fm)
         CALL cp_fm_to_fm_submat(matrix_mm_fm, h_block, nmat, nmat, 1, 1, 1, 1)

         ! compute the bottom left  ZSC (top right is transpose)
         CALL dbcsr_multiply('t', 'n', 1.0_dp, matrix_pz, matrix_sc, 0.0_dp, matrix_mm, last_column=nmat)
         !  set the bottom left part of S[C,Z] block matrix  ZSC
         !copy sparse to full
         CALL copy_dbcsr_to_fm(matrix_mm, matrix_mm_fm)
         CALL cp_fm_to_fm_submat(matrix_mm_fm, s_block, nmat, nmat, 1, 1, 1 + nmat, 1)
         CALL cp_fm_transpose(matrix_mm_fm, matrix_mmt_fm)
         CALL cp_fm_to_fm_submat(matrix_mmt_fm, s_block, nmat, nmat, 1, 1, 1, 1 + nmat)

         ! compute the bottom left  ZHC (top right is transpose)
         CALL dbcsr_multiply('t', 'n', 1.0_dp, matrix_pz, matrix_hc, 0.0_dp, matrix_mm, last_column=nmat)
         ! set the bottom left part of H[C,Z] block matrix  ZHC
         CALL copy_dbcsr_to_fm(matrix_mm, matrix_mm_fm)
         CALL cp_fm_to_fm_submat(matrix_mm_fm, h_block, nmat, nmat, 1, 1, 1 + nmat, 1)
         CALL cp_fm_transpose(matrix_mm_fm, matrix_mmt_fm)
         CALL cp_fm_to_fm_submat(matrix_mmt_fm, h_block, nmat, nmat, 1, 1, 1, 1 + nmat)

         CALL cp_fm_release(matrix_mmt_fm)

         ! (reuse matrix_sc and matrix_hc to compute HZ and SZ)
         CALL dbcsr_get_info(matrix_pz, nfullrows_total=n, nfullcols_total=k)
         CALL dbcsr_multiply('n', 'n', 1.0_dp, c_out, matrix_pz, 0.0_dp, matrix_hc, last_column=k)
         CALL dbcsr_multiply('n', 'n', 1.0_dp, matrix_s, matrix_pz, 0.0_dp, matrix_sc, last_column=k)

         ! compute the bottom right  ZSZ
         CALL dbcsr_multiply('t', 'n', 1.0_dp, matrix_pz, matrix_sc, 0.0_dp, matrix_mm, last_column=k)
         ! set the bottom right part of S[C,Z] block matrix  ZSZ
         CALL copy_dbcsr_to_fm(matrix_mm, matrix_mm_fm)
         CALL cp_fm_to_fm_submat(matrix_mm_fm, s_block, nmat, nmat, 1, 1, 1 + nmat, 1 + nmat)

         ! compute the bottom right  ZHZ
         CALL dbcsr_multiply('t', 'n', 1.0_dp, matrix_pz, matrix_hc, 0.0_dp, matrix_mm, last_column=k)
         ! set the bottom right part of H[C,Z] block matrix  ZHZ
         CALL copy_dbcsr_to_fm(matrix_mm, matrix_mm_fm)
         CALL cp_fm_to_fm_submat(matrix_mm_fm, h_block, nmat, nmat, 1, 1, 1 + nmat, 1 + nmat)

         CALL dbcsr_release_p(matrix_mm)
         CALL dbcsr_release_p(matrix_sc)
         CALL dbcsr_release_p(matrix_hc)

         CALL reduce_extended_space(s_block, h_block, v_block, w_block, evals, nmat2)

         ! allocate two (nao x nmat) full matrix
         CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=nao, ncol_global=nmat, &
                                  context=mo_coeff%matrix_struct%context, &
                                  para_env=mo_coeff%matrix_struct%para_env)
         CALL cp_fm_create(matrix_nm_fm, fm_struct_tmp, name="m_nxm")
         CALL cp_fm_create(matrix_z_fm, fm_struct_tmp, name="m_nxm")
         CALL cp_fm_struct_release(fm_struct_tmp)

         CALL copy_dbcsr_to_fm(matrix_pz, matrix_z_fm)
         ! extract eigenvectors
         CALL cp_fm_to_fm_submat(v_block, matrix_mm_fm, nmat, nmat, 1, 1, 1, 1)
         CALL parallel_gemm('N', 'N', nao, nmat, nmat, 1.0_dp, mo_notconv_fm, matrix_mm_fm, 0.0_dp, matrix_nm_fm)
         CALL cp_fm_to_fm_submat(v_block, matrix_mm_fm, nmat, nmat, 1 + nmat, 1, 1, 1)
         CALL parallel_gemm('N', 'N', nao, nmat, nmat, 1.0_dp, matrix_z_fm, matrix_mm_fm, 1.0_dp, matrix_nm_fm)

         CALL dbcsr_release_p(matrix_z)
         CALL dbcsr_release_p(matrix_pz)
         CALL cp_fm_release(matrix_z_fm)
         CALL cp_fm_release(s_block)
         CALL cp_fm_release(h_block)
         CALL cp_fm_release(w_block)
         CALL cp_fm_release(v_block)
         CALL cp_fm_release(matrix_mm_fm)

         ! in case some vector are already converged only a subset of vectors are copied in the MOS
         IF (nmo_converged > 0) THEN
            jj = 1
            DO j = 1, nset_not_conv
               i_first = inotconv_set(j, 1)
               i_last = inotconv_set(j, 2)
               n = i_last - i_first + 1
               CALL cp_fm_to_fm_submat(matrix_nm_fm, mo_coeff, nao, n, 1, jj, 1, i_first)
               eigenvalues(i_first:i_last) = evals(jj:jj + n - 1)
               jj = jj + n
            END DO
            DEALLOCATE (iconv_set)
            DEALLOCATE (inotconv_set)

            CALL dbcsr_release_p(mo_notconv)
            CALL dbcsr_release_p(c_out)
            CALL cp_fm_release(mo_notconv_fm)
            DEALLOCATE (mo_notconv_fm)
         ELSE
            CALL cp_fm_to_fm(matrix_nm_fm, mo_coeff)
            eigenvalues(1:nmo) = evals(1:nmo)
         END IF
         DEALLOCATE (evals)

         CALL cp_fm_release(matrix_nm_fm)
         CALL copy_fm_to_dbcsr(mo_coeff, mo_coeff_b) !fm->dbcsr

         t2 = m_walltime()
         IF (output_unit > 0) THEN
            WRITE (output_unit, '(T16,I5,T24,I6,T33,E12.4,2x,E12.4,T60,F8.3)') &
               iteration, nmo_converged, max_norm, min_norm, t2 - t1
         END IF
         t1 = m_walltime()

      END DO ! iteration

      DEALLOCATE (ritz_coeff)
      DEALLOCATE (iconv)
      DEALLOCATE (inotconv)
      DEALLOCATE (vnorm)

      CALL timestop(handle)

   END SUBROUTINE generate_extended_space_sparse

! **************************************************************************************************

! **************************************************************************************************
!> \brief Solves the generalized eigenvalue problem defined by h_block and s_block
!>        by reducing it to a standard eigenvalue problem.
!>        A Cholesky decomposition of s_block is tried first; if it reports a
!>        failure (info /= 0) the routine falls back to a transformation with
!>        S^(-1/2) computed by cp_fm_power.
!> \param s_block overlap-like matrix of the extended space; its content is
!>                overwritten (it is used as storage for the transformation matrix)
!> \param h_block Hamiltonian-like matrix of the extended space; its content is
!>                overwritten by the transformed matrix that is diagonalized
!> \param v_block on exit, the eigenvectors transformed back with the matrix held
!>                in s_block
!> \param w_block work matrix; on exit it holds the eigenvectors returned by
!>                choose_eigv_solver for the transformed problem
!> \param evals on exit, the eigenvalues returned by choose_eigv_solver
!> \param ndim dimension passed to the Cholesky restore / gemm calls
!>             (the block matrices are treated as ndim x ndim)
!> \note The cp_fm_type arguments are declared INTENT(IN) although the data they
!>       refer to is modified by the called routines.
! **************************************************************************************************
   SUBROUTINE reduce_extended_space(s_block, h_block, v_block, w_block, evals, ndim)

      TYPE(cp_fm_type), INTENT(IN)                       :: s_block, h_block, v_block, w_block
      REAL(dp), DIMENSION(:), INTENT(OUT)                :: evals
      INTEGER, INTENT(IN)                                :: ndim

      CHARACTER(len=*), PARAMETER :: routineN = 'reduce_extended_space'

      INTEGER                                            :: handle, info

      CALL timeset(routineN, handle)

      ! keep a copy of S in w_block: the decomposition below works in place on s_block,
      ! and the fallback branch needs the original matrix
      CALL cp_fm_to_fm(s_block, w_block)
      CALL cp_fm_cholesky_decompose(s_block, info_out=info)
      IF (info == 0) THEN
         ! Cholesky succeeded: invert the triangular factor in place
         CALL cp_fm_triangular_invert(s_block)
         ! two-sided transformation of H with the inverted factor
         ! (right multiplication into w_block, then left multiplication with the
         !  transpose back into h_block)
         CALL cp_fm_cholesky_restore(h_block, ndim, s_block, w_block, "MULTIPLY", pos="RIGHT")
         CALL cp_fm_cholesky_restore(w_block, ndim, s_block, h_block, "MULTIPLY", pos="LEFT", transa="T")
         ! standard eigenvalue problem: eigenvectors in w_block, eigenvalues in evals
         CALL choose_eigv_solver(h_block, w_block, evals)
         ! back-transform the eigenvectors into v_block
         CALL cp_fm_cholesky_restore(w_block, ndim, s_block, v_block, "MULTIPLY")
      ELSE
         ! Cholesky failed: build S^(-1/2) from the saved copy of S
         ! (s_block is passed as the second matrix argument, presumably as workspace;
         !  the value returned in info is not used afterwards)
         CALL cp_fm_power(w_block, s_block, -0.5_dp, 1.0E-5_dp, info)
         CALL cp_fm_to_fm(w_block, s_block)
         ! h_block <- S^(-1/2) * H * S^(-1/2), using w_block as intermediate
         CALL parallel_gemm('N', 'N', ndim, ndim, ndim, 1.0_dp, h_block, s_block, 0.0_dp, w_block)
         CALL parallel_gemm('N', 'N', ndim, ndim, ndim, 1.0_dp, s_block, w_block, 0.0_dp, h_block)
         ! standard eigenvalue problem: eigenvectors in w_block, eigenvalues in evals
         CALL choose_eigv_solver(h_block, w_block, evals)
         ! back-transform the eigenvectors: v_block <- S^(-1/2) * w_block
         CALL parallel_gemm('N', 'N', ndim, ndim, ndim, 1.0_dp, s_block, w_block, 0.0_dp, v_block)
      END IF

      CALL timestop(handle)

   END SUBROUTINE reduce_extended_space

END MODULE qs_scf_block_davidson
